load("//bazel:api.bzl", "mojo_filecheck_test", "mojo_test")
load("//bazel:config.bzl", "DEFAULT_GPU_MEMORY")

package(default_visibility = ["//max:consumers"])

# Reusable constraint fragments for _EXTRA_CONSTRAINTS. Each one is appended to
# a target's `target_compatible_with` list by the test comprehensions below.

# Marks the target incompatible when //:apple_gpu matches; adds nothing otherwise.
_SKIP_ON_APPLE_GPU = select({
    "//:apple_gpu": ["@platforms//:incompatible"],
    "//conditions:default": [],
})

# Keeps the target compatible only when //:h100_gpu or //:b200_gpu matches.
_H100_OR_B200_ONLY = select({
    "//:h100_gpu": ["//:h100_gpu"],
    "//:b200_gpu": ["//:b200_gpu"],
    "//conditions:default": ["@platforms//:incompatible"],
})

# Keeps the target compatible only when //:b200_gpu matches.
_B200_ONLY = select({
    "//:b200_gpu": ["//:b200_gpu"],
    "//conditions:default": ["@platforms//:incompatible"],
})

# Per-source additions to `target_compatible_with`, keyed by test file name.
# Sources that are not listed here get no extra constraints.
_EXTRA_CONSTRAINTS = {
    "test_async_copy_wgmma.mojo": ["//:h100_gpu"],
    "test_dual_gemm.mojo": ["//:nvidia_gpu"],  # FIXME: KERN-1377
    "test_gemm_kernel.mojo": ["//:nvidia_gpu"],  # FIXME: KERN-1377
    "test_gemv2.mojo": select({
        "//:h100_gpu": ["@platforms//:incompatible"],
        "//:apple_gpu": ["@platforms//:incompatible"],  # FIXME: PRDT-506
        "//conditions:default": [],
    }),
    "test_gemv_bf16.mojo": ["//:nvidia_gpu"],  # FIXME: KERN-1377
    "test_gemv_tma.mojo": _H100_OR_B200_ONLY,
    "test_grouped_matmul.mojo": _H100_OR_B200_ONLY,
    "test_matmul_amd.mojo": ["//:amd_gpu"],  # TODO: Disabled FIXME: KERN-1377
    "test_structured_kernel_components.mojo": ["//:amd_gpu"],
    "test_matmul_selection_heuristic.mojo": ["//:a100_gpu"],
    "test_matmul_sm90_epilogue.mojo": ["//:h100_gpu"],
    "test_stream_k.mojo": ["//:nvidia_gpu"],
    "test_matmul_tile_scheduler_sm100.mojo": ["//:b200_gpu"],
    # NOTE(review): a plain list requires *every* listed constraint, so this
    # entry asks for B200 and H100 at once, unlike the "either" selects used
    # elsewhere in this table. Confirm this (and the leading underscore in the
    # file name) is intentional.
    "_test_tma.mojo": [
        "//:b200_gpu",
        "//:h100_gpu",
    ],
    "test_tma_wgmma.mojo": ["//:h100_gpu"],
    "test_tma_wgmma_with_multicast.mojo": ["@platforms//:incompatible"],  # FIXME: KERN-2060
    "test_vendor_blas.mojo": [
        "//:nvidia_gpu",
        "//:has_multi_gpu",
    ],
    "test_matmul_sm90_fp8.mojo": ["//:h100_gpu"],
    "test_matmul_sm90_bf16_multicast.mojo": ["//:h100_gpu"],
    "test_matmul_sm90_hilbert.mojo": ["//:h100_gpu"],
    "test_matmul_sm90_bf16.mojo": ["//:h100_gpu"],
    "test_matmul_sm90_deepseek_scheduler.mojo": ["//:h100_gpu"],
    "test_sm90_splitk.mojo": ["//:h100_gpu"],
    "test_multistage_gemm_kernel.mojo": _SKIP_ON_APPLE_GPU,  # FIXME: long running on MacOS
    "test_linalg_matmul_gpu.mojo": _SKIP_ON_APPLE_GPU,  # FIXME: MOCO-2366
    "test_gemv.mojo": _SKIP_ON_APPLE_GPU,  # FIXME: MOCO-2397
    "test_grouped_matmul_vendor.mojo": _SKIP_ON_APPLE_GPU,  # FIXME: PRDT-506
    "test_overlap_matmul_allreduce.mojo": _SKIP_ON_APPLE_GPU,  # FIXME: PRDT-506
    "test_batched_matmul.mojo": _SKIP_ON_APPLE_GPU,  # FIXME: PRDT-506
    "test_matmul_custom.mojo": _SKIP_ON_APPLE_GPU,  # FIXME: PRDT-506
    "test_matmul.mojo": _SKIP_ON_APPLE_GPU,  # FIXME: PRDT-506
    "test_matmul_tile_scheduler.mojo": _SKIP_ON_APPLE_GPU,  # FIXME: MOCO-2405
    "test_grouped_matmul_tile_scheduler.mojo": _SKIP_ON_APPLE_GPU,
    "test_naive_blockwise_fp8_matmul.mojo": _H100_OR_B200_ONLY,
    "test_cublaslt_scaled_mxfp8_matmul.mojo": _B200_ONLY,
    "test_cublaslt_scaled_nvfp4_matmul.mojo": _B200_ONLY,
}

# Sources built with `mojo_filecheck_test` instead of `mojo_test`. They are
# excluded from the catch-all glob at the bottom of this file.
_FILECHECK_TESTS = [
    "test_block_swizzle.mojo",
    "test_matmul_selection_heuristic.mojo",
    "test_matmul_tile_scheduler.mojo",
    "test_matmul_tile_scheduler_sm100.mojo",
    "test_grouped_matmul_tile_scheduler.mojo",
]

# Sources that get the dedicated `mojo_test` configuration with the
# `//max/kernels/test/testdata:ptx_files` data dependency and assertions
# disabled. They are excluded from the catch-all glob at the bottom of this file.
_ASM_CHECK_TESTS = [
    "test_matmul_sm100_ptx.mojo",
]

# Sources that are tagged "b200" and constrained to `//:b200_gpu`. They are
# excluded from the catch-all glob at the bottom of this file.
_B200_TESTS = [
    "test_tma_mma_sm100.mojo",
    "test_tma_pair_mma_sm100.mojo",
    "test_matmul_sm100_fallback.mojo",
    "test_matmul_sm100_composable.mojo",
    "test_matmul_sm100_blockwise_fp8.mojo",
    "test_grouped_matmul_dynamic_scaled_fp8.mojo",
    "test_matmul_sm100_1sm_blockwise_fp8.mojo",
    "test_matmul_sm100_2sm_blockwise_fp8.mojo",
    "test_grouped_matmul_sm100_blockwise_fp8.mojo",
    "test_batched_matmul_sm100_scaled.mojo",
    "test_matmul_sm100_epilogue.mojo",
    "test_matmul_sm100_swapAB_epilogue.mojo",
    "test_matmul_sm100_2sm_bf16.mojo",
    "test_matmul_sm100_1sm_bf16.mojo",
    "test_matmul_sm100_2sm_fp8.mojo",
    "test_matmul_sm100_1sm_fp8.mojo",
    "test_matmul_sm100_splitk_2sm_bf16.mojo",
    "test_tma_mma_sm100_mxfp8.mojo",
    "test_tma_pair_mma_sm100_mxfp8.mojo",
]

# Per-source override for the "test.resources:gpu-memory" exec property.
# Sources that are not listed fall back to DEFAULT_GPU_MEMORY. Values are
# passed through as strings; the unit is not visible in this file (TODO confirm
# against //bazel:config.bzl).
_GPU_MEMORY = {
    "test_batched_matmul.mojo": "4",
    "test_dual_gemm.mojo": "2",
    "test_gemv2.mojo": "2",
    "test_matmul.mojo": "3",
    "test_matmul_custom.mojo": "4",
    "test_matmul_sm90_bf16_multicast.mojo": "2",
    "test_matmul_sm90_deepseek_scheduler.mojo": "1",
    "test_matmul_sm90_epilogue.mojo": "1",
    "test_matmul_sm90_fp8.mojo": "1",
    "test_matmul_sm90_hilbert.mojo": "1",
    "test_sm90_splitk.mojo": "1.5",
    "test_vendor_blas.mojo": "3",
    "test_warp_specialization_gemm_with_cluster.mojo": "2",
    # NOTE(review): the two keys below match no name in _B200_TESTS (which
    # lists 1sm/2sm variants); they only take effect if files with exactly
    # these names are picked up by the catch-all glob. Confirm they are not stale.
    "test_matmul_sm100_bf16.mojo": "3",
    "test_matmul_sm100_fp8.mojo": "3",
    "test_matmul_sm100_epilogue.mojo": "3",
    "test_overlap_matmul_allreduce.mojo": "1.5",
}

# One `mojo_filecheck_test` target, named "<source>.test", per entry in
# _FILECHECK_TESTS.
[
    mojo_filecheck_test(
        name = "%s.test" % filecheck_src,
        size = "large",
        srcs = [filecheck_src],
        exec_properties = {
            # Per-source override when present, shared default otherwise.
            "test.resources:gpu-memory": _GPU_MEMORY.get(filecheck_src, DEFAULT_GPU_MEMORY),
        },
        tags = ["gpu"],
        # Always require a GPU; add any per-source restrictions on top.
        target_compatible_with = ["//:has_gpu"] + _EXTRA_CONSTRAINTS.get(filecheck_src, []),
        deps = [
            "//max:internal_utils",
            "//max:linalg",
            "@mojo//:stdlib",
        ],
    )
    for filecheck_src in _FILECHECK_TESTS
]

# One `mojo_test` target, named "<source>.test", per entry in _ASM_CHECK_TESTS.
# NOTE(review): unlike the other test groups in this file, these targets set no
# "gpu" tag, no gpu-memory exec property and no `target_compatible_with`;
# confirm that is intentional.
[
    mojo_test(
        name = "%s.test" % asm_src,
        size = "large",
        srcs = [asm_src],
        data = ["//max/kernels/test/testdata:ptx_files"],
        enable_assertions = False,  # PTX comparison assertions disabled
        deps = [
            "//max:linalg",
            "//max:testdata",
            "@mojo//:stdlib",
            "@mojo//:test_utils",
        ],
    )
    for asm_src in _ASM_CHECK_TESTS
]

# One `mojo_test` target, named "<source>.test", per entry in _B200_TESTS.
# Every target is tagged "b200" and constrained to `//:b200_gpu`;
# _EXTRA_CONSTRAINTS is not consulted for this group.
[
    mojo_test(
        name = "%s.test" % b200_src,
        size = "large",
        srcs = [b200_src],
        exec_properties = {
            # Per-source override when present, shared default otherwise.
            "test.resources:gpu-memory": _GPU_MEMORY.get(b200_src, DEFAULT_GPU_MEMORY),
        },
        tags = [
            "b200",
            "gpu",
        ],
        target_compatible_with = [
            "//:has_gpu",
            "//:b200_gpu",
        ],
        deps = [
            "//max:internal_utils",
            "//max:linalg",
            "@mojo//:stdlib",
            "@mojo//:test_utils",
        ],
    )
    for b200_src in _B200_TESTS
]

# Catch-all group: one `mojo_test` target, named "<source>.test", for every
# .mojo file under this package that is not already claimed by
# _FILECHECK_TESTS, _ASM_CHECK_TESTS, or _B200_TESTS.
[
    mojo_test(
        name = src + ".test",
        size = "large",
        srcs = [src],
        # Assertions stay enabled for every source except the ones listed here.
        enable_assertions = src not in [
            "test_warp_specialization_gemm.mojo",
            "test_generated_sm90_matmul_ptx.mojo",
        ],  # TODO: Fix assertions exception
        exec_properties = {
            # Per-source override when present, shared default otherwise.
            "test.resources:gpu-memory": _GPU_MEMORY.get(src, DEFAULT_GPU_MEMORY),
        },
        # Compare with `==`: `in` against a string is a substring test, which
        # would also tag any source whose name is a substring of this literal.
        tags = ["gpu"] + (
            ["manual"] if src == "test_multistage_gemm_fp8.mojo" else []  #  TODO: KERN-1480 Fix and remove this tag
        ),
        # Always require a GPU; add any per-source restrictions on top.
        target_compatible_with = ["//:has_gpu"] + _EXTRA_CONSTRAINTS.get(src, []),
        deps = [
            "//max:internal_utils",
            "//max:linalg",
            "@mojo//:stdlib",
            "@mojo//:test_utils",
        ],
    )
    for src in glob(
        ["**/*.mojo"],
        exclude = _FILECHECK_TESTS + _ASM_CHECK_TESTS + _B200_TESTS,
    )
]

# Exposes every .mojo file under this package (including the ones handled by
# the dedicated test groups above) to the gpu-build-benchmarking packages.
filegroup(
    name = "test-sources",
    srcs = glob(["**/*.mojo"]),
    visibility = ["//utils/debugging/gpu-build-benchmarking:__subpackages__"],
)
